#!/usr/bin/env python
# -*- coding:utf-8 -*-


import asyncio
from datetime import datetime
import re
import time
from typing import Text
import libs.request as request
import libs.tools as tools
from libs.errhandle import handler_exception
import libs.printer as printer
import traceback
import itertools
try:
    import aiohttp
except ImportError as e:
    printer.warn("请使用pip安装aiohttp=3.7.4")

# Baseline captured from the response to the unmodified original request.
# broute.call_request fills in the 'resp', 'text' and 'spent' keys and compares
# the other responses against them once they are present.
comp = dict()

class broute(object):
    """Asynchronous payload fuzzer that sends templated HTTP requests.

    Requests are built from templates (URL, headers, body) that contain
    position markers (``self._tag`` in mode 1, ``$<id>$`` markers in the
    other modes) and from one or more wordlists.  They are put on an
    ``asyncio.Queue`` and processed by ``self._limit`` concurrent consumers.
    Each response is compared with the response of the unmodified original
    request, and responses accepted by ``tools.is_matching`` are printed.

    Modes (``mode`` argument, dispatched by :meth:`loop`):
        1 -- one marker, one wordlist (:meth:`fill_queue`)
        2 -- "sniper": recorded positions one at a time (:meth:`sniper_mode`)
        3 -- "pitchfork": wordlists advanced in lockstep (:meth:`pitchfork_mode`)
        4 -- "cluster": cartesian product of wordlists (:meth:`cluster_mode`)
    """

    # Elapsed-time difference (milliseconds) versus the original request
    # above which a response is reported as slower/faster.
    _TIME_DIFF_MS = 50

    def __init__(self, url, words, verb, limit, delay, headers, data, timeout,
                 hc, sc, ht, st, tag, orireq, record, mode):
        """Store the fuzzing configuration.

        Args:
            url: URL template (may contain position markers).
            words: payload source.  Mode 1 takes a list of payloads or a
                wordlist path.  Mode 2 takes a list of wordlist paths.
                Modes 3 and 4 take a list of ``"<id>$<path>"`` entries.
            verb: HTTP method forwarded to ``request.Request``.
            limit: number of concurrent consumer coroutines.
            delay: seconds to sleep after each processed response (0 = none).
            headers: header template dict (values may contain markers).
            data: request body template string.
            timeout: forwarded to ``request.Request``.
            hc, sc, ht, st: response filters forwarded to ``tools.is_matching``.
            tag: marker substituted in mode 1.
            orireq: original request as a dict with ``url``, ``headers`` and
                ``params`` keys.
            record: maps a request part (``"url"``, ``"headers"``,
                ``"params"``) to a dict ``{marker: original value}``.
            mode: 1-4, see the class docstring.
        """
        self._url = url
        self._words = words
        self._verb = verb
        self._limit = limit
        self._delay = delay
        self._headers = headers
        self._data = data
        self._timeout = timeout
        self._hc = hc
        self._sc = sc
        self._ht = ht
        self._st = st
        self._tag = tag
        self._orireq = orireq
        self._record = record
        self._mode = mode

    # ------------------------------------------------------------------
    # Request construction
    # ------------------------------------------------------------------
    def rebuild_request(self, word, url):
        """Build a request with ``self._tag`` replaced by *word*.

        The marker is substituted in a copy of the header template, in the
        body template and in *url*.
        """
        headers = dict(self._headers.items())
        replaced = tools.replace_tag(headers, self._tag, word)
        # Every other call site uses replace_tag's return value; this one used
        # to rely on in-place mutation only.  Accept either convention.
        if replaced is not None:
            headers = replaced
        data = self._data.replace(self._tag, word) if self._data else ""
        return request.Request(url=url.replace(self._tag, word), headers=headers,
                               data=data, verb=self._verb, timeout=self._timeout,
                               word=word, sign=False)

    def original_request(self):
        """Build the unmodified original request (flagged with ``sign=True``)."""
        data = self._orireq['params'] or ''
        return request.Request(url=self._orireq['url'], headers=self._orireq['headers'],
                               data=data, verb=self._verb, timeout=self._timeout,
                               word=data, sign=True)

    def pitchfork_request(self, url, headers, params, word):
        """Build a request from already-substituted parts; *word* labels it."""
        return request.Request(url=url, headers=headers, data=params or '',
                               verb=self._verb, timeout=self._timeout,
                               word=word, sign=False)

    def params_replace_tag(self, params: str, word, tag="$"):
        """Return *params* with every occurrence of *word* replaced by *tag*.

        Callers pass a position marker as *word* and a payload as *tag*.
        If ``tools.judge_tag`` reports no match, *params* is returned as-is,
        so a marker that lives in another request part cannot blank this one.
        """
        if tools.judge_tag(params, word):
            return params.replace(word, tag)
        return params

    # ------------------------------------------------------------------
    # Wordlist handling
    # ------------------------------------------------------------------
    def _tag_wordlists(self):
        """Split ``"<id>$<path>"`` entries into parallel (markers, paths) lists.

        A repeated id keeps its first position but takes the later path.
        """
        tag_paths = {}
        for entry in self._words:
            # maxsplit=1 so a '$' inside the path does not break unpacking
            ident, path = entry.split("$", 1)
            tag_paths['$' + ident + '$'] = path
        return list(tag_paths.keys()), list(tag_paths.values())

    def _combine_wordlists(self, combine):
        """Yield ``(words, tags)`` for each tuple that *combine* produces.

        *combine* receives one open text file per wordlist (``zip`` or
        ``itertools.product``).  Newlines are stripped from each word.  The
        files are closed when the generator finishes or is closed.
        """
        tags, paths = self._tag_wordlists()
        readers = []
        try:
            for path in paths:
                readers.append(open(path, encoding="utf-8"))
            for lines in combine(*readers):
                yield [line.replace("\n", '') for line in lines], tags
        finally:
            for reader in readers:
                reader.close()

    def yield_cluster_comb(self):
        """Yield ``(words, tags)`` for every combination of wordlist lines."""
        yield from self._combine_wordlists(itertools.product)

    def yield_comb(self):
        """Yield ``(words, tags)`` taking one line from each wordlist in lockstep.

        Iteration stops at the end of the shortest wordlist.
        """
        yield from self._combine_wordlists(zip)

    # ------------------------------------------------------------------
    # Queue producers
    # ------------------------------------------------------------------
    async def sniper_mode(self):
        """Queue the original request plus one request per word per position.

        For every wordlist path in ``self._words``, and for every recorded
        part (url / headers / params) with more than one position, position
        ``$i$`` is marked on the original value in turn.  Every word is then
        substituted into it via :meth:`rebuild_request`.

        NOTE(review): ``self._tag``, ``self._url``, ``self._headers`` and
        ``self._data`` are overwritten while queueing.  Later parts build on
        that mutated state.
        """
        self._queue = asyncio.Queue()
        await self._queue.put(self.original_request())
        for wordlist in self._words:
            with open(wordlist, encoding='utf-8') as fp:
                payloads = fp.read().splitlines()
            for part, positions in self._record.items():
                if part not in ('url', 'headers', 'params'):
                    continue
                count = len(positions)
                if count <= 1:
                    continue
                for i in range(count):
                    tag = '$' + str(i) + '$'
                    self._tag = tag
                    if part == 'url':
                        self._url = self._orireq['url'].replace(positions[tag], tag)
                    elif part == 'headers':
                        headers = dict(self._orireq['headers'].items())
                        self._headers = tools.replace_tag(headers, positions[tag], tag)
                    else:
                        self._data = self._orireq['params'].replace(positions[tag], tag)
                    for payload in payloads:
                        await self._queue.put(self.rebuild_request(payload, self._url))
        self._qsize = self._queue.qsize()

    async def _queue_combinations(self, combinations):
        """Queue the original request plus one request per word combination.

        *combinations* is a zero-argument callable returning an iterator of
        ``(words, tags)``.  It is called once per recorded part.  Parts that
        are not fuzzed are taken from the original request.
        """
        self._queue = asyncio.Queue()
        await self._queue.put(self.original_request())
        original = self._orireq
        for part in self._record:
            if part == "url":
                for words, tags in combinations():
                    url = self._url
                    for tag, word in zip(tags, words):
                        url = self.params_replace_tag(url, tag, word)
                    await self._queue.put(self.pitchfork_request(
                        url, original['headers'], original['params'], str(words)))
            elif part == "params":
                for words, tags in combinations():
                    data = self._data
                    for tag, word in zip(tags, words):
                        data = self.params_replace_tag(data, tag, word)
                    await self._queue.put(self.pitchfork_request(
                        original['url'], original['headers'], data, str(words)))
            elif part == "headers":
                for words, tags in combinations():
                    headers = self._headers.copy()
                    for tag, word in zip(tags, words):
                        headers = tools.replace_tag(headers, tag, word)
                    await self._queue.put(self.pitchfork_request(
                        original['url'], headers, original['params'], str(words)))
        self._qsize = self._queue.qsize()

    async def cluster_mode(self):
        """Queue requests for the cartesian product of all wordlists."""
        await self._queue_combinations(self.yield_cluster_comb)

    async def pitchfork_mode(self):
        """Queue requests taking one word from each wordlist in lockstep."""
        await self._queue_combinations(self.yield_comb)

    async def fill_queue(self):
        """Queue the original request plus one request per word (mode 1)."""
        self._queue = asyncio.Queue()
        await self._queue.put(self.original_request())
        if isinstance(self._words, list):
            words = self._words
        else:
            with open(self._words, encoding="utf-8") as fp:
                words = fp.read().splitlines()
        for word in words:
            await self._queue.put(self.rebuild_request(word, self._url))
        self._qsize = self._queue.qsize()

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------
    async def stop(self):
        """Discard all queued requests so that ``queue.join()`` can return.

        Each discarded item is marked done.  The item currently being
        processed is marked done by :meth:`call_request`.  The event loop is
        deliberately not stopped here: doing so inside ``run_until_complete``
        makes it raise "Event loop stopped before Future completed".
        """
        while True:
            try:
                self._queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            self._queue.task_done()

    async def handle_request(self, request):
        """Send *request* and return its result dict, or None on failure."""
        try:
            return await request.send()
        except aiohttp.InvalidURL:
            printer.warn("无效的url!")
            # Every request shares the same broken URL; drop the rest.
            await self.stop()
        except Exception:
            printer.warn("载荷:{}, 请求失败!".format(request._word))
        return None

    async def trigger_coros(self):
        """Run ``self._limit`` consumers until every queued item is done.

        The consumers loop forever, so they are cancelled once the queue is
        joined.  This keeps pending tasks from outliving the run.
        """
        # gather() already schedules the consumer tasks; it returns a Future,
        # which loop.create_task() would reject with TypeError.
        workers = asyncio.gather(*(self.consumer() for _ in range(max(1, self._limit))))
        try:
            await self._queue.join()
        finally:
            workers.cancel()
            try:
                await workers
            except asyncio.CancelledError:
                pass

    async def work(self):
        """Mode 1: fill the queue from a single wordlist and process it."""
        await self.fill_queue()
        await self.trigger_coros()

    async def sniper_work(self):
        """Mode 2: fill the queue in sniper mode and process it."""
        await self.sniper_mode()
        await self.trigger_coros()

    async def pitchfork_work(self):
        """Mode 3: fill the queue in pitchfork mode and process it."""
        await self.pitchfork_mode()
        await self.trigger_coros()

    async def cluster_work(self):
        """Mode 4: fill the queue in cluster mode and process it."""
        await self.cluster_mode()
        await self.trigger_coros()

    async def consumer(self):
        """Process queued requests forever (cancelled by :meth:`trigger_coros`)."""
        while True:
            request = await self._queue.get()
            await self.call_request(request)

    @staticmethod
    def _text_shape(text):
        """Return (characters, space-separated chunks, lines) of *text*."""
        return len(text), len(text.split(" ")), len(text.splitlines())

    async def call_request(self, request):
        """Send *request*, compare it with the baseline and print the result.

        The response to the original request (``isOrign`` true) is stored in
        the module-level ``comp`` dict as the baseline.  Other responses are
        compared with that baseline, if it has already arrived.  The checks
        are, in order: length fingerprint, status code, elapsed time.  The
        first difference found is reported.  The queue item is always marked
        done, even on failure, so ``queue.join()`` cannot hang.
        """
        global comp
        # Local, not an attribute: several consumers run concurrently.
        started = time.time()
        length_diff = status_diff = time_diff = ''
        try:
            data = await self.handle_request(request)
            if data is None:
                printer.warn("响应信息为空!")
                return
            spent = int((time.time() - started) * 1000)  # milliseconds
            resp = data['response']
            content = data['text']
            shape = self._text_shape(content)
            if data['isOrign']:
                comp['resp'] = resp
                comp['text'] = content
                comp['spent'] = spent
            elif comp:
                if sum(shape) != sum(self._text_shape(comp['text'])):
                    length_diff = "长度不一致！"
                if int(resp.status) != int(comp['resp'].status):
                    status_diff = "状态码不一致!"
                delta = spent - comp['spent']
                if delta > self._TIME_DIFF_MS:
                    time_diff = "慢了{}ms以上!".format(self._TIME_DIFF_MS)
                elif delta < -self._TIME_DIFF_MS:
                    time_diff = "快了{}ms以上!".format(self._TIME_DIFF_MS)
            if tools.is_matching(resp.status, content, hc=self._hc, ht=self._ht,
                                 st=self._st, sc=self._sc):
                len_content = "-".join(str(n) for n in shape)
                difference = length_diff or status_diff or time_diff
                if difference:
                    printer.warn("载荷: {}, 长度: {}, 时间:{}, 状态码:{}, 对比:{}".format(
                        request._word, len_content, spent, resp.status, difference))
                else:
                    printer.plus("载荷: {}, 长度: {}, 时间:{}, 状态码:{}".format(
                        request._word, len_content, spent, resp.status))
            if self._delay > 0:
                await asyncio.sleep(self._delay)
        except Exception as e:
            print(e)
        finally:
            self._queue.task_done()

    def loop(self):
        """Run the configured mode to completion on the asyncio event loop.

        NOTE: the event loop is stored on ``self.loop``.  After the first
        call, that attribute shadows this method on the instance.
        """
        self.loop = asyncio.get_event_loop()
        self.loop.set_exception_handler(handler_exception)
        runners = {
            1: self.work,
            2: self.sniper_work,
            3: self.pitchfork_work,
            4: self.cluster_work,
        }
        try:
            runner = runners.get(self._mode)
            if runner is not None:
                self.loop.run_until_complete(runner())
        except KeyboardInterrupt:
            print("终止!")
        except Exception as e:
            print(e)
            traceback.print_exc()